# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

import torch
from torch.utils.data import DataLoader

from pruna.engine.pruna_model import PrunaModel
from pruna.engine.utils import device_to_string, split_device


class BaseMetric(ABC):
    """The base class for all Pruna metrics."""

    metric_name: str
    metric_units: str
    higher_is_better: bool
    runs_on: list[str] = ["cuda", "cpu", "mps"]

    @property
    def device(self) -> str:
        """Return the current device."""
        return getattr(self, "_device", "cuda")

    @device.setter
    def device(self, dvc: str | torch.device):
        """Validate and store the device.

        Parameters
        ----------
        value : str | torch.device
            The device to set.
        """
        # To prevent the user from setting an unsupported device.
        if not self.is_device_supported(dvc):
            raise ValueError(
                f"Metric {getattr(self, 'metric_name', self.__class__.__name__)} "
                f"doesn't support device '{dvc}'. Must be one of {self.runs_on}."
            )

        object.__setattr__(self, "_device", device_to_string(dvc))

        if hasattr(self, "metric"):  # For memory metrics that internally use the GPUMemoryStats.
            object.__setattr__(self.metric, "_device", device_to_string(dvc))

    @abstractmethod
    def compute(
        self,
        model: PrunaModel,
        dataloader: DataLoader,
    ) -> Any:
        """
        Compute the metric value.

        Parameters
        ----------
        model : PrunaModel
            The model to evaluate.
        dataloader : DataLoader
            The dataloader to use for the evaluation.

        Returns
        -------
        Any
            The computed metric value.
        """
        pass

    def is_device_supported(self, device: str | torch.device) -> bool:
        """
        Check if the metric is compatible with the device.

        Parameters
        ----------
        device : str | torch.device | None, optional
            The device to check. If None, the metric is compatible with all devices.

        Returns
        -------
        bool
            True if the metric is compatible with the device, False otherwise.
        """
        dvc, _ = split_device(device_to_string(device))
        return dvc in self.runs_on
